
Conversation

ElizaWszola (Contributor) commented on Oct 31, 2025

CUDA kernel for fused blockwise quant RMS norm

Testing:
pytest tests/kernels/core/test_fused_quant_layernorm.py

TODO: E2E tests, cleanup

Signed-off-by: ElizaWszola <[email protected]>
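
For reference, below is a minimal unfused PyTorch sketch (not the CUDA kernel added in this PR) of what the fused op computes: an RMS norm over the hidden dimension followed by dynamic groupwise FP8 quantization with group size [1, 128]. The function name, epsilon, and scale clamping here are illustrative assumptions.

```python
import torch


def rms_norm_groupwise_fp8_ref(x: torch.Tensor,
                               weight: torch.Tensor,
                               group_size: int = 128,
                               eps: float = 1e-6):
    """Unfused reference: RMS norm, then per-[1, group_size] FP8 quantization."""
    # RMS norm over the last dimension, computed in fp32 for accuracy.
    x_f32 = x.float()
    rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = x_f32 * rms * weight.float()

    # Dynamic groupwise quantization: one scale per contiguous group of
    # `group_size` elements along the hidden dimension.
    n, d = normed.shape
    assert d % group_size == 0
    groups = normed.view(n, d // group_size, group_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (groups / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(n, d), scales.view(n, d // group_size)
```

The fused kernel computes the same result in a single pass over each row rather than as two separate kernels, which is where the savings in the benchmarks below come from.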
mergify bot added the performance (Performance-related issues) label on Nov 7, 2025
ElizaWszola changed the title from [Performance] Blockwise quant RMS norm to [Performance] Fused blockwise quant RMS norm on Nov 7, 2025
yewentao256 (Member) commented

The optimization in this commit is beneficial:
Before

[-------------------------------------------- rms-norm-dynamic-per-token-quant --------------------------------------------]
                                                                  |  unfused_groupwise_fp8_impl  |  fused_groupwise_fp8_impl
1 threads: -----------------------------------------------------------------------------------------------------------------
      N 1 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]      |             31.4             |            29.4          
      N 1 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]     |             34.0             |            30.4          
      N 1 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]      |             31.3             |            29.6          
      N 1 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]     |             34.0             |            29.5          
      N 4 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]      |             30.1             |            29.5          
      N 4 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]     |             35.1             |            31.2          
      N 4 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]      |             32.4             |            32.5          
      N 4 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]     |             36.1             |            30.7          
      N 16 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]     |             31.6             |            31.4          
      N 16 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]    |             35.2             |            32.3          
      N 16 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]     |             32.8             |            32.2          
      N 16 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]    |             35.1             |            31.6          
      N 64 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]     |             31.8             |            31.5          
      N 64 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]    |             35.2             |            32.7          
      N 64 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]     |             31.8             |            31.6          
      N 64 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]    |             36.1             |            32.1          
      N 256 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]    |             32.8             |            32.3          
      N 256 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]   |             36.1             |            32.0          
      N 256 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]    |             32.6             |            32.3          
      N 256 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]   |             35.2             |            31.5          
      N 1024 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]   |             31.4             |            39.0          
      N 1024 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]  |             35.1             |            36.9          
      N 1024 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]   |             31.8             |            53.3          
      N 1024 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]  |             35.5             |            49.3   

After

[-------------------------------------------- rms-norm-dynamic-per-token-quant --------------------------------------------]
                                                                  |  unfused_groupwise_fp8_impl  |  fused_groupwise_fp8_impl
1 threads: -----------------------------------------------------------------------------------------------------------------
      N 1 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]      |             30.9             |            19.6          
      N 1 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]     |             36.5             |            19.4          
      N 1 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]      |             30.5             |            19.6          
      N 1 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]     |             36.5             |            19.6          
      N 4 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]      |             30.4             |            19.5          
      N 4 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]     |             34.2             |            19.3          
      N 4 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]      |             30.5             |            19.6          
      N 4 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]     |             34.2             |            19.4          
      N 16 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]     |             31.8             |            19.6          
      N 16 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]    |             36.4             |            19.5          
      N 16 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]     |             30.7             |            19.7          
      N 16 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]    |             36.5             |            19.7          
      N 64 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]     |             31.8             |            19.7          
      N 64 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]    |             36.5             |            19.6          
      N 64 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]     |             30.4             |            19.6          
      N 64 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]    |             34.3             |            19.5          
      N 256 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]    |             30.1             |            19.4          
      N 256 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]   |             34.4             |            19.8          
      N 256 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]    |             30.7             |            19.6          
      N 256 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]   |             34.2             |            19.5          
      N 1024 x D 1024 x R True x DT torch.bfloat16x GS [1, 128]   |             30.7             |            19.4          
      N 1024 x D 1024 x R False x DT torch.bfloat16x GS [1, 128]  |             34.4             |            19.4          
      N 1024 x D 5120 x R True x DT torch.bfloat16x GS [1, 128]   |             30.7             |            28.7          
      N 1024 x D 5120 x R False x DT torch.bfloat16x GS [1, 128]  |             34.5             |            28.7 
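
For context, here is a minimal sketch of how a comparison table in this format can be produced with torch.utils.benchmark. It is not the actual vLLM benchmark script, and the two callables are placeholders standing in for the unfused and fused groupwise-quant RMS norm paths.

```python
import torch
import torch.utils.benchmark as benchmark


def placeholder_rms_norm(x, w, eps=1e-6):
    # Stand-in for either implementation; a real benchmark would call the
    # unfused ops and the fused CUDA kernel from this PR instead.
    xf = x.float()
    return (xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
            * w.float()).to(x.dtype)


impls = {
    "unfused_groupwise_fp8_impl": placeholder_rms_norm,
    "fused_groupwise_fp8_impl": placeholder_rms_norm,
}

results = []
for n in (1, 4, 16, 64, 256, 1024):
    for d in (1024, 5120):
        x = torch.randn(n, d, dtype=torch.bfloat16, device="cuda")
        w = torch.randn(d, dtype=torch.bfloat16, device="cuda")
        for name, fn in impls.items():
            results.append(
                benchmark.Timer(
                    stmt="fn(x, w)",
                    globals={"fn": fn, "x": x, "w": w},
                    label="rms-norm-dynamic-per-token-quant",
                    sub_label=f"N {n} x D {d} x DT torch.bfloat16 x GS [1, 128]",
                    description=name,
                    num_threads=1,
                ).blocked_autorange(min_run_time=1.0)
            )

# Print all measurements grouped by label/sub_label/description, similar to
# the tables above.
benchmark.Compare(results).print()
```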

mergify bot commented on Nov 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ElizaWszola.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label on Nov 10, 2025